#include "Graphics/DXDevice.h"
#include "Graphics/DXUtilities.h"

#include <dxgi1_6.h>
#include <exception>
#include <stdexcept>

/// Creates the D3D12 device on the most capable hardware adapter.
///
/// Steps: enables the D3D12 debug layer (debug builds only), creates a DXGI
/// factory, selects the non-software adapter that can create a device at
/// D3D_FEATURE_LEVEL_12_1 and reports the most dedicated video memory, creates
/// the device on it and configures debug-message filtering.
///
/// @param checkForRayTracingSupport When true, additionally require the device
///        to report D3D12_RAYTRACING_TIER_1_0 or higher.
/// @throws std::runtime_error if no suitable hardware adapter is found, or if
///         ray tracing is required but not supported. Failing HRESULTs from
///         factory/device creation are passed to ThrowIfFailed.
DXDevice::DXDevice(bool checkForRayTracingSupport)
{
	// Minimum feature level used both to probe adapters and to create the final device.
	const D3D_FEATURE_LEVEL minimumFeatureLevel = D3D_FEATURE_LEVEL_12_1;

	// 1. Enable Debug Layer in case we run with debug mode //
	DebugLayer();

	// 2. Create a factory //
	ComPtr<IDXGIFactory4> dxgiFactory;
	unsigned int createFactoryFlags = 0;

#if defined(_DEBUG)
	createFactoryFlags = DXGI_CREATE_FACTORY_DEBUG;
#endif

	ThrowIfFailed(CreateDXGIFactory2(createFactoryFlags, IID_PPV_ARGS(&dxgiFactory)));

	// 3. Query the most suitable Adapter to create the Device with //
	ComPtr<IDXGIAdapter1> dxgiAdapter1;
	ComPtr<IDXGIAdapter4> dxgiAdapter4;

	SIZE_T maxDedicatedVideoMemory = 0;
	for(UINT i = 0; dxgiFactory->EnumAdapters1(i, &dxgiAdapter1) != DXGI_ERROR_NOT_FOUND; ++i)
	{
		DXGI_ADAPTER_DESC1 dxgiAdapterDesc1 = {};

		// Skip adapters whose description cannot be read instead of inspecting an uninitialized struct.
		if(FAILED(dxgiAdapter1->GetDesc1(&dxgiAdapterDesc1)))
		{
			continue;
		}

		const bool isHardwareAdapter = (dxgiAdapterDesc1.Flags & DXGI_ADAPTER_FLAG_SOFTWARE) == 0;

		// The first qualifying adapter is always accepted, even when it reports 0 bytes of
		// dedicated memory; afterwards only adapters with strictly more memory replace it.
		const bool isMoreCapable = !dxgiAdapter4 ||
			dxgiAdapterDesc1.DedicatedVideoMemory > maxDedicatedVideoMemory;

		// Passing nullptr as the output pointer only tests whether device creation would succeed.
		// This probe is evaluated last so it is skipped for adapters that could not be selected anyway.
		if(isHardwareAdapter && isMoreCapable &&
			SUCCEEDED(D3D12CreateDevice(dxgiAdapter1.Get(),
				minimumFeatureLevel, __uuidof(ID3D12Device), nullptr)))
		{
			maxDedicatedVideoMemory = dxgiAdapterDesc1.DedicatedVideoMemory;
			ThrowIfFailed(dxgiAdapter1.As(&dxgiAdapter4));
		}
	}

	// Without this check a null adapter would make D3D12CreateDevice silently fall back to the
	// default adapter, which may be the software adapter rejected above.
	if(!dxgiAdapter4)
	{
		LOG(Log::MessageType::Error, "No hardware adapter supporting D3D_FEATURE_LEVEL_12_1 was found!");
		throw std::runtime_error("No hardware adapter supporting D3D_FEATURE_LEVEL_12_1 was found!");
	}

	// 4. With the chosen Adapter create the Device //
	ThrowIfFailed(D3D12CreateDevice(dxgiAdapter4.Get(), minimumFeatureLevel, IID_PPV_ARGS(&device)));

	// 5. Set up message severities for debug messages //
	SetupMessageSeverities();

	// Optional: Check if the device supports DXR //
	if(checkForRayTracingSupport)
	{
		// Verify if Ray Tracing is supported //
		D3D12_FEATURE_DATA_D3D12_OPTIONS5 optionData = {};
		HRESULT result = device->CheckFeatureSupport(D3D12_FEATURE_D3D12_OPTIONS5, &optionData, sizeof(optionData));

		if(FAILED(result) || optionData.RaytracingTier < D3D12_RAYTRACING_TIER_1_0)
		{
			LOG(Log::MessageType::Error, "Device/Driver does NOT support (DX)Ray Tracing!");
			// std::runtime_error derives from std::exception, so existing handlers still catch it.
			throw std::runtime_error("Device/Driver does NOT support (DX)Ray Tracing!");
		}
	}
}

// Returns the device as a ComPtr. Copying a ComPtr takes its own reference,
// so the returned object keeps the device alive independently of this DXDevice.
ComPtr<ID3D12Device5> DXDevice::Get()
{
	ComPtr<ID3D12Device5> deviceReference = device;
	return deviceReference;
}

// Returns the raw device interface pointer (not the address of the ComPtr).
// No reference is added; the pointer is only valid while this DXDevice still holds the device.
ID3D12Device5* DXDevice::GetAddress()
{
	ID3D12Device5* const rawDevice = device.Get();
	return rawDevice;
}

// Enables the D3D12 debug layer in _DEBUG builds; compiles to an empty function otherwise.
// Called before device creation in the constructor, since the layer must be enabled before a device exists.
void DXDevice::DebugLayer()
{
#ifdef _DEBUG
	ComPtr<ID3D12Debug> debugController;
	const HRESULT debugInterfaceResult = D3D12GetDebugInterface(IID_PPV_ARGS(debugController.GetAddressOf()));
	ThrowIfFailed(debugInterfaceResult);
	debugController->EnableDebugLayer();
#endif
}

// In _DEBUG builds, configures the device's info queue:
// - break on CORRUPTION, ERROR and WARNING messages;
// - push a storage filter that drops all INFO messages plus a few specific message IDs.
// If the device does not expose ID3D12InfoQueue, nothing is changed. Empty in non-debug builds.
void DXDevice::SetupMessageSeverities()
{
#ifdef _DEBUG
	ComPtr<ID3D12InfoQueue> infoQueue;
	if(FAILED(device.As(&infoQueue)))
	{
		// No info queue available on this device: leave message handling untouched.
		return;
	}

	// Break when a message of any of these severities is reported (applied in this order).
	const D3D12_MESSAGE_SEVERITY breakSeverities[] =
	{
		D3D12_MESSAGE_SEVERITY_CORRUPTION,
		D3D12_MESSAGE_SEVERITY_ERROR,
		D3D12_MESSAGE_SEVERITY_WARNING,
	};
	for(const D3D12_MESSAGE_SEVERITY severity : breakSeverities)
	{
		infoQueue->SetBreakOnSeverity(severity, TRUE);
	}

	// Whole severity levels to drop from storage.
	// Non-const because D3D12_INFO_QUEUE_FILTER stores non-const pointers.
	D3D12_MESSAGE_SEVERITY deniedSeverities[] = { D3D12_MESSAGE_SEVERITY_INFO };

	// Individual message IDs to drop from storage.
	// NOTE(review): presumably considered benign for this renderer; confirm and document why per ID.
	D3D12_MESSAGE_ID deniedIds[] =
	{
		D3D12_MESSAGE_ID_CLEARRENDERTARGETVIEW_MISMATCHINGCLEARVALUE,
		D3D12_MESSAGE_ID_MAP_INVALID_NULLRANGE,
		D3D12_MESSAGE_ID_UNMAP_INVALID_NULLRANGE,
	};

	D3D12_INFO_QUEUE_FILTER storageFilter = {};
	storageFilter.DenyList.NumSeverities = _countof(deniedSeverities);
	storageFilter.DenyList.pSeverityList = deniedSeverities;
	storageFilter.DenyList.NumIDs = _countof(deniedIds);
	storageFilter.DenyList.pIDList = deniedIds;

	ThrowIfFailed(infoQueue->PushStorageFilter(&storageFilter));
#endif
}